{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "import torch\n",
    "from torch import nn\n",
    "from torchvision import models\n",
    "from nupic.torch.modules import KWinners, KWinners2d\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Lambda(nn.Module):\n",
    "    def __init__(self, func):\n",
    "        super().__init__()\n",
    "        self.func = func\n",
    "    \n",
    "    def forward(self, x): \n",
    "        return self.func(x)\n",
    "\n",
    "def Flatten():\n",
    "    return Lambda(lambda x: x.view((x.size(0), -1)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "class GSCHeb(nn.Module):\n",
    "    \"\"\"\n",
    "    Simple 3 hidden layers + output MLPHeb, similar to one used in SET Paper.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, config=None):\n",
    "        super(GSCHeb, self).__init__()\n",
    "\n",
    "        defaults = dict(\n",
    "            device='cpu',\n",
    "            input_size=1024,\n",
    "            num_classes=12,\n",
    "            hebbian_learning=True,\n",
    "            boost_strength=1.5,\n",
    "            boost_strength_factor=0.9,\n",
    "            k_inference_factor=1.5,\n",
    "            duty_cycle_period=1000\n",
    "        )\n",
    "        defaults.update(config or {})\n",
    "        self.__dict__.update(defaults)\n",
    "        self.device = torch.device(self.device)\n",
    "\n",
    "        # hidden layers\n",
    "        layers = [\n",
    "            *self._conv_block(1, 64, percent_on=0.095), # 28x28 -> 14x14\n",
    "            *self._conv_block(64, 64, percent_on=0.125), # 10x10 -> 5x5\n",
    "            Flatten(),\n",
    "            *self._linear_block(25*64, 1000, percent_on=0.1),\n",
    "        ]\n",
    "        # output layer\n",
    "        layers.append(nn.Linear(1000, self.num_classes))\n",
    "\n",
    "        # classifier (*redundancy on layers to facilitate traversing)\n",
    "        self.layers = layers\n",
    "        self.classifier = nn.Sequential(*layers)\n",
    "\n",
    "        # track correlations\n",
    "        self.correlations = []\n",
    "\n",
    "    def _conv_block(self, fin, fout, percent_on=0.1):\n",
    "        block = [\n",
    "            nn.Conv2d(fin, fout, kernel_size=5, stride=1, padding=0),\n",
    "            nn.BatchNorm2d(fout, affine=False),\n",
    "            nn.MaxPool2d(kernel_size=2, stride=2),       \n",
    "            self._kwinners(fout, percent_on),\n",
    "        ]\n",
    "        return block\n",
    "\n",
    "    def _linear_block(self, fin, fout, percent_on=0.1):\n",
    "        block = [\n",
    "            nn.Linear(fin, fout), \n",
    "            nn.BatchNorm1d(fout, affine=False),\n",
    "            self._kwinners(fout, percent_on, twod=False),\n",
    "        ]\n",
    "        return block\n",
    "\n",
    "    def _kwinners(self, fout, percent_on, twod=True):\n",
    "        if twod:\n",
    "            activation_func = KWinners2d\n",
    "        else:\n",
    "            activation_func = KWinners\n",
    "        return activation_func(\n",
    "            fout,\n",
    "            percent_on=percent_on,\n",
    "            boost_strength=self.boost_strength,\n",
    "            boost_strength_factor=self.boost_strength_factor,\n",
    "            k_inference_factor=self.k_inference_factor,\n",
    "            duty_cycle_period=self.duty_cycle_period\n",
    "        )\n",
    "\n",
    "    def _has_activation(self, idx, layer):\n",
    "        return (\n",
    "            idx == len(self.layers) - 1\n",
    "            or isinstance(layer, KWinners)\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"A faster and approximate way to track correlations\"\"\"\n",
    "        # x = x.view(-1, self.input_size)  # resiaze if needed, eg mnist\n",
    "        prev_act = (x > 0).detach().float()\n",
    "        idx_activation = 0\n",
    "        for idx_layer, layer in enumerate(self.layers):\n",
    "            # do the forward calculation normally\n",
    "            x = layer(x)\n",
    "            if self.hebbian_learning:\n",
    "                n_samples = x.shape[0]\n",
    "                if self._has_activation(idx_layer, layer):\n",
    "                    with torch.no_grad():\n",
    "                        curr_act = (x > 0).detach().float()\n",
    "                        # add outer product to the correlations, per sample\n",
    "                        for s in range(n_samples):\n",
    "                            outer = torch.ger(prev_act[s], curr_act[s])\n",
    "                            if idx_activation + 1 > len(self.correlations):\n",
    "                                self.correlations.append(outer)\n",
    "                            else:\n",
    "                                self.correlations[idx_activation] += outer\n",
    "                        # reassigning to the next\n",
    "                        prev_act = curr_act\n",
    "                        # move to next activation\n",
    "                        idx_activation += 1\n",
    "\n",
    "        return x\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchsummary import summary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "ename": "RuntimeError",
     "evalue": "vector and vector expected, got 3D, 1D tensors at ../aten/src/TH/generic/THTensorMath.cpp:886",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-5-50977db11cdd>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mGSCHeb\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0msummary\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m32\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m32\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m~/miniconda3/envs/numenta/lib/python3.7/site-packages/torchsummary/torchsummary.py\u001b[0m in \u001b[0;36msummary\u001b[0;34m(model, input_size, batch_size, device)\u001b[0m\n\u001b[1;32m     70\u001b[0m     \u001b[0;31m# make a forward pass\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     71\u001b[0m     \u001b[0;31m# print(x.shape)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 72\u001b[0;31m     \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     73\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     74\u001b[0m     \u001b[0;31m# remove these hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/numenta/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m    545\u001b[0m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    546\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 547\u001b[0;31m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    548\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    549\u001b[0m             \u001b[0mhook_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<ipython-input-3-de48f15e580d>\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m     91\u001b[0m                         \u001b[0;31m# add outer product to the correlations, per sample\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     92\u001b[0m                         \u001b[0;32mfor\u001b[0m \u001b[0ms\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mn_samples\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 93\u001b[0;31m                             \u001b[0mouter\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mger\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mprev_act\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0ms\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcurr_act\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0ms\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     94\u001b[0m                             \u001b[0;32mif\u001b[0m \u001b[0midx_activation\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;36m1\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcorrelations\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     95\u001b[0m                                 \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcorrelations\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mouter\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mRuntimeError\u001b[0m: vector and vector expected, got 3D, 1D tensors at ../aten/src/TH/generic/THTensorMath.cpp:886"
     ]
    }
   ],
   "source": [
    "model = GSCHeb()\n",
    "summary(model, (1, 32, 32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
